"""
Phase 8: Robustness
=====================
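Table 14: Robustness matrix — Z₁ (or leading demographic proxy) coefficient
across 5 DVs × 7 specification groups

DVs: ca_gdp, gross_assets_gdp, gross_liab_gdp, income_balance_gdp, debt_assets_gdp
Specs: Full sample, OECD, non-OECD, excluding financial centers, alternative
demographics (old_dep, youth_dep), income terciles (Low/Mid/High, one row each),
and predetermined demographics (oadr_plus20). The alternative-demographics,
oadr_plus20 and tercile rows are included only when their variables are available.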
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# Project layout: this script sits in a subdirectory of PROJECT_DIR, and the
# shared estimation code lives in a sibling "multilateral" project under ROOT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Make multilateral/src importable; this must run before the `model` import below.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Inputs are read from data/processed; tables are written to output/tables,
# which is created (if missing) as a side effect of importing this module.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# OECD member economies (38 ISO3 codes) defining the OECD / non-OECD split.
OECD = (
    'AUS AUT BEL CAN CHL COL CRI CZE DNK EST '
    'FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN '
    'KOR LVA LTU LUX MEX NLD NZL NOR POL PRT '
    'SVK SVN ESP SWE CHE TUR GBR USA'
).split()
# Economies dropped in the "Excl FC" (excluding financial centers) subsample.
FINANCIAL_CENTERS = 'LUX IRL HKG SGP CHE NLD BEL'.split()

# Demographic regressors; the Z_1 coefficient is the one reported in the matrix.
DEMO_VARS = [f'Z_{k}' for k in (1, 2, 3)]
# Candidate controls; only those present with enough non-missing data are used.
EBA_CONTROLS = 'fiscal_bal_gdp nfa_gdp_lag log_rel_opw kaopen'.split()


def stars(p):
    """Translate a p-value into conventional significance stars.

    Returns ``'***'`` when p < 0.01, ``'**'`` when p < 0.05, ``'*'`` when
    p < 0.1 and ``''`` otherwise (including NaN, for which every comparison
    is false).
    """
    for cutoff, mark in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return mark
    return ''


def run_gls_z1(df, y_var, x_vars, quiet=True, min_obs=50,
               proxies=('Z_1', 'old_dep', 'oadr_plus20')):
    """Run PanelGLS and return the leading demographic coefficient.

    The "leading" regressor is the first name in *proxies* that appears in
    *x_vars* (by default ``Z_1``, then ``old_dep``, then ``oadr_plus20``).

    Parameters
    ----------
    df : pandas.DataFrame
        Panel holding *y_var*, every column in *x_vars*, ``iso3`` and ``year``.
    y_var : str
        Dependent-variable column.
    x_vars : list of str
        Regressor columns, in the order they are passed to ``PanelGLS.fit``.
    quiet : bool
        Unused; kept for interface compatibility.
    min_obs : int
        Minimum number of complete rows (after dropping missing values)
        required to attempt the fit.
    proxies : tuple of str
        Candidate regressors, in priority order, whose coefficient is reported.

    Returns
    -------
    dict or None
        Keys ``coef``, ``se``, ``p``, ``n_obs``, ``n_countries``,
        ``r_squared``; or ``None`` when no proxy is among *x_vars* or fewer
        than *min_obs* complete rows remain.
    """
    # Resolve the reported coefficient before fitting, so a model whose
    # result would be thrown away is never estimated.
    z1_idx = next((x_vars.index(v) for v in proxies if v in x_vars), None)
    if z1_idx is None:
        return None

    cols = [y_var] + x_vars + ['iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_vars].values,
            sub['iso3'].values, sub['year'].values)

    # NOTE(review): assumes gls.beta / se / pvalues are ordered like x_vars
    # (no prepended intercept) — confirm against PanelGLS.
    return {
        'coef': gls.beta[z1_idx],
        'se': gls.se[z1_idx],
        'p': gls.pvalues[z1_idx],
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }


def _build_specs(df, controls, alt_demo, has_oadr, has_tercile):
    """Return the (label, panel, regressors) triples, one per matrix row."""
    base_vars = DEMO_VARS + controls
    specs = [
        ('Full', df, base_vars),
        ('OECD', df[df['iso3'].isin(OECD)].copy(), base_vars),
        ('Non-OECD', df[~df['iso3'].isin(OECD)].copy(), base_vars),
        ('Excl FC', df[~df['iso3'].isin(FINANCIAL_CENTERS)].copy(), base_vars),
    ]

    if alt_demo:
        # Alternative demographics: dependency ratios replace the Z factors.
        specs.append(('Alt demo', df, alt_demo + controls))

    if has_oadr:
        # Predetermined demographics.
        specs.append(('OADR+20', df, ['oadr_plus20'] + controls))

    if has_tercile:
        for tercile in ('Low', 'Mid', 'High'):
            sub = df[df['income_tercile'] == tercile].copy()
            # Skip terciles with too few rows to be worth estimating.
            if len(sub) > 100:
                specs.append((f'Inc: {tercile}', sub, base_vars))

    return specs


def _build_matrix(specs, dvs):
    """Estimate every (spec, DV) cell, echo each row, and return the rows.

    Each row maps ``'Specification'`` to the spec label, each DV to a
    formatted ``coef+stars`` string (``'--'`` when no estimate), and
    ``'<dv>_n'`` to the observation count (0 when no estimate).
    """
    matrix = []
    for spec_name, spec_df, spec_vars in specs:
        row = {'Specification': spec_name}
        for dv in dvs:
            r = run_gls_z1(spec_df, dv, spec_vars)
            if r is not None:
                row[dv] = f"{r['coef']:.2f}{stars(r['p'])}"
                row[f'{dv}_n'] = r['n_obs']
            else:
                row[dv] = '--'
                row[f'{dv}_n'] = 0
        matrix.append(row)
        dv_str = ' | '.join(f"{row[dv]:>10s}" for dv in dvs)
        print(f"  {spec_name:15s}  {dv_str}")
    return matrix


def _render_table(matrix, dvs, controls, alt_demo, has_oadr):
    """Format the matrix rows and footnotes as one markdown document."""
    labels = [dv.replace('_gdp', '') for dv in dvs]
    lines = [
        "# Table 14: Robustness Matrix — Z₁ Coefficient Across Specifications\n",
        "| Specification | " + " | ".join(labels) + " |",
        "|:---|" + "|".join("---:" for _ in dvs) + "|",
    ]
    for row in matrix:
        cells = "".join(f" {row.get(dv, '--')} |" for dv in dvs)
        lines.append(f"| {row['Specification']} |{cells}")

    lines.append("\n*Each cell shows the Z₁ (or leading demographic proxy) coefficient.*")
    lines.append("*Panel GLS with AR(1) errors. Controls: " + ', '.join(controls) + ".*")

    # Describe only the alternative specifications that were actually run.
    method_notes = []
    if alt_demo:
        method_notes.append(f"Alt demo uses {', '.join(alt_demo)}.")
    if has_oadr:
        method_notes.append("OADR+20 uses predetermined dependency.")
    if method_notes:
        lines.append("*" + " ".join(method_notes) + "*")

    # Derived from FINANCIAL_CENTERS so the footnote cannot drift from the filter.
    lines.append(f"*Financial centers excluded: {', '.join(FINANCIAL_CENTERS)}.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")
    return '\n'.join(lines)


def main():
    """Build Table 14 and save it as output/tables/robustness_matrix.md.

    Reads data/processed/net_gross_panel.csv (years <= 2024), estimates the
    leading demographic coefficient for every (specification, DV) pair via
    ``run_gls_z1``, prints each row, and writes a markdown table.
    """
    print("=" * 70)
    print("PHASE 8: ROBUSTNESS")
    print("=" * 70)

    df = pd.read_csv(DATA / "net_gross_panel.csv")
    df = df[df['year'] <= 2024].copy()
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    # Keep only controls with enough non-missing data to estimate on.
    controls = [c for c in EBA_CONTROLS if c in df.columns and df[c].notna().sum() > 200]

    # Dependent variables present in the panel.
    dvs = [v for v in ('ca_gdp', 'gross_assets_gdp', 'gross_liab_gdp',
                       'income_balance_gdp', 'debt_assets_gdp')
           if v in df.columns]

    # Optional specifications, enabled only when their variables are available.
    alt_demo = [v for v in ('old_dep', 'youth_dep') if v in df.columns]
    has_oadr = 'oadr_plus20' in df.columns and df['oadr_plus20'].notna().sum() > 200
    has_tercile = 'income_tercile' in df.columns and df['income_tercile'].notna().sum() > 200

    # ══════════════════════════════════════════════════════════════════
    # TABLE 14: ROBUSTNESS MATRIX
    # ══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 50)
    print("TABLE 14: ROBUSTNESS MATRIX")
    print("=" * 50)

    specs = _build_specs(df, controls, alt_demo, has_oadr, has_tercile)
    matrix = _build_matrix(specs, dvs)

    out_path = OUT_TABLES / "robustness_matrix.md"
    # The table contains non-ASCII text (Z₁, em dash): write UTF-8 explicitly
    # instead of relying on the platform's locale encoding.
    out_path.write_text(
        _render_table(matrix, dvs, controls, alt_demo, has_oadr),
        encoding='utf-8',
    )
    print(f"\n  Saved: {out_path.name}")

    print("\n" + "=" * 70)
    print("PHASE 8 COMPLETE")
    print("=" * 70)


# Build and save the robustness matrix only when run as a script, not on import.
if __name__ == '__main__':
    main()
